# -*- coding: utf-8 -*- #
"""
Time                2023/3/19 20:45
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                redis_service_discovery.py
Description:
"""
import time
from typing import List, Dict
from cmg_base import ServiceDiscovery, ServiceInstance, CmgUrl, \
    LoadBalancingStrategy, create_load_balancing_strategy, \
    LoadBalancingStrategyBase
from .common import ClientBase, StaticConst


class RedisServiceDiscovery(ServiceDiscovery, ClientBase):
    """Redis-backed implementation of ServiceDiscovery.

    Keys read from Redis by this class:
      * ``StaticConst.CMG_REDIS_SERVICES`` -- a set of inner service names;
      * ``<inner service name>`` -- a set of service ids of that service;
      * ``<service id>`` -- a hash describing one instance, decoded with
        ``ServiceInstance.from_redis_mapping``.

    Instance lists are cached per service for ``fetch_interval`` seconds,
    and each service may have one load-balancing strategy that is kept in
    sync with its cached instance list.
    """

    def __init__(self, url: CmgUrl, fetch_interval: int = 5):
        ServiceDiscovery.__init__(self, fetch_interval)
        ClientBase.__init__(self, url)

        # Not used by this class; kept so any external reference keeps
        # working. NOTE(review): presumably reserved for a background
        # refresh thread -- confirm before removing.
        self._auto_fetch_thread = None
        # service_name -> cached list of instances
        self._services: Dict[str, List[ServiceInstance]] = {}
        # service_name -> load-balancing strategy fed with the cached list
        self._strategy_map: Dict[str, LoadBalancingStrategyBase] = {}
        # service_name -> time.monotonic() value of the last Redis fetch
        self._last_fetch_time_map: Dict[str, float] = {}

    def get_services(self) -> List[str]:
        """Return the (origin) names of all services registered in Redis."""
        return [
            StaticConst.get_origin_service_name(inner_service_name)
            for inner_service_name in
            self.redis_client.smembers(StaticConst.CMG_REDIS_SERVICES)
        ]

    def _get_instances(self, service_name: str) -> List[ServiceInstance]:
        """Fetch the current instances of *service_name* directly from Redis.

        The service's id set is read first, then every instance hash is
        fetched in a single pipeline. An id that disappears between the two
        steps (instance unregistered or its key expired) yields an empty
        mapping from HGETALL; such entries are skipped rather than decoded.
        """
        inner_service_name = StaticConst.get_inner_service_name(service_name)
        service_ids = self.redis_client.smembers(inner_service_name)
        if not service_ids:
            # Nothing registered: no need for a pipeline round-trip.
            return []
        pipeline = self.redis_client.pipeline()
        for service_id in service_ids:
            pipeline.hgetall(service_id)
        return [ServiceInstance.from_redis_mapping(mapping)
                for mapping in pipeline.execute() if mapping]

    def get_instances(self, service_name: str, force: bool = False) \
            -> List[ServiceInstance]:
        """Return the instances of *service_name*, using a short-lived cache.

        Redis is queried when:
          1. *force* is true, or
          2. the service has never been fetched, or
          3. the cached copy is older than ``fetch_interval`` seconds.

        After a fetch, the service's load-balancing strategy (if one exists)
        is updated with the new list.
        """
        last_fetch = self._last_fetch_time_map.get(service_name)
        if force or service_name not in self._services or \
                last_fetch is None or \
                last_fetch + self.fetch_interval < time.monotonic():
            instances = self._get_instances(service_name)
            self._services[service_name] = instances
            # Monotonic clock: cache age is unaffected by wall-clock jumps.
            self._last_fetch_time_map[service_name] = time.monotonic()
            current_strategy = self._strategy_map.get(service_name)
            if current_strategy is not None:
                current_strategy.update(instances)
        return self._services[service_name]

    def get_instance(self, service_name: str, strategy: LoadBalancingStrategy,
                     force: bool = False) \
            -> ServiceInstance:
        """Select one instance of *service_name* using *strategy*.

        One strategy object is kept per service. It is recreated when a
        different strategy type is requested. A newly created strategy is
        always seeded with the current instance list, even when that list
        is served from cache rather than freshly fetched.
        """
        current = self._strategy_map.get(service_name)
        if current is not None and current.type() == strategy:
            # Refreshes the cache (and updates `current`) only if stale.
            self.get_instances(service_name, force)
            return current.select()

        new_strategy = create_load_balancing_strategy(strategy)
        # Detach the outdated strategy so get_instances() does not update it,
        # then seed the new one explicitly: if the cache is still fresh,
        # get_instances() performs no fetch and would never call update().
        self._strategy_map.pop(service_name, None)
        instances = self.get_instances(service_name, force)
        new_strategy.update(instances)
        self._strategy_map[service_name] = new_strategy
        return new_strategy.select()
